home *** CD-ROM | disk | FTP | other *** search
/ Tech Arsenal 1 / Tech Arsenal (Arsenal Computer).ISO / tek-04 / nasanets.zip / NETMAIN.C < prev    next >
C/C++ Source or Header  |  1990-06-07  |  21KB  |  597 lines

  1. /*=============================*/
  2. /*           NETS              */
  3. /*                             */
  4. /* a product of the AI Section */
  5. /* NASA, Johnson Space Center  */
  6. /*                             */
  7. /* principal author:           */
  8. /*       Paul Baffes           */
  9. /*                             */
  10. /* contributing authors:       */
  11. /*      Bryan Dulock           */
  12. /*      Chris Ortiz            */
  13. /*=============================*/
  14.  
  15.  
  16. /*
  17. ----------------------------------------------------------------------
  18.   This is the main module for the neural net.  All of the main code   
  19.   plus documentation on the system features lives here.
  20.  
  21.   A note on nomenclature: Most of the routine calls which you will 
  22.   see in the code throughout these files will be "prefixed" by some 
  23.   letters and an underscore.  That is one of my conventions for indicating
  24.   which file contains the code for the subroutine in question. Because
  25.   I have several files, this makes tracing problems and debugging
  26.   files easier to accomplish.  
  27.  
  28.   This file is one of 14 source files which make up the back
  29.   propagation code. These files are the following:
  30.  
  31.       activate.c              semi-linear activation function
  32.       buildnet.c
  33.       compile.c
  34.       convert.c               conversion routines, conversion routines
  35.       layer.c                 layer manipulation and creation
  36.       lnrate.c
  37.       net.c                   net manipulation, learning, propagation
  38.       netio.c                 I/O routines, file handlers
  39.       netmain.c               main routines, menus, user interface
  40.       pairs.c
  41.       prop.c
  42.       shownet.c
  43.       teach.c
  44.       weights.c               weights manipulation and creation
  45.  
  46.   All of these are covered in detail in the respective files. The prefix
  47.   codes for each of the files are as follows:
  48.  
  49.       activate.c              "A_"
  50.       buildnet.c              "B_"
  51.       compile.c               "CC_"
  52.       convert.c               "C_"
  53.       dribble.c               "D_"
  54.       net.c                   "N_"
  55.       netio.c                 "IO_"
  56.       netmain.c               *NONE*
  57.       pairs.c                 "PA_"
  58.       parser.c                "PS_"
  59.       prop.c                  "P_"
  60.       layer.c                 "L_"
  61.       lnrate.c                "LR_"
  62.       shownet.c               "S_"
  63.       teach.c                 "T_"
  64.       weights.c               "W_"
  65.       sysdep.c                "sys_"    (system dependent code)
  66.  
  67.   The rest of this file is organized into the folloing groups:
  68.  
  69.   (1) include files
  70.   (2) externed functions
  71.   (3) global variables
  72.   (4) subroutines
  73. ----------------------------------------------------------------------
  74. */
  75.  
  76.  
  77. /*
  78. ----------------------------------------------------------------------
  79.   INCLUDE FILES
  80. ----------------------------------------------------------------------
  81. */
  82. #include  "common.h"
  83. #include  "netio.h"
  84. #include  "weights.h"
  85. #include  "layer.h"
  86. #include  "net.h"
  87.  
  88.  
  89. /*
  90. ----------------------------------------------------------------------
  91.   EXTERNED FUNCTIONS
  92. ----------------------------------------------------------------------
  93. */
  94. extern Net    *B_create_net();
  95. extern Net    *B_free_net();
  96.  
  97. extern void   N_query_net();
  98. extern int    N_reset_wts();
  99. extern void   N_save_wts();
  100. extern Layer  *N_get_layer();
  101.  
  102. extern void   T_teach_net();
  103.  
  104. extern void   S_show_weights();
  105. extern void   S_show_biases();
  106. extern void   S_show_net();
  107.  
  108. extern float  IO_my_get_float();
  109. extern int    IO_my_get_int();
  110. extern int    IO_get_default_int();
  111. extern int    IO_get_num_cycles();
  112. extern void   IO_my_get_string();
  113. extern void   IO_set_filenames();
  114. extern void   IO_get_io_name();
  115. extern void   IO_get_wts_name();
  116. extern void   IO_print();
  117.  
  118. extern void   PA_initialize();
  119. extern void   PA_setup_iopairs();
  120. extern void   PA_reset_iopairs();
  121. extern void   PA_randomize_file();
  122.  
  123. extern void   D_initialize();
  124. extern void   D_dribble_status(); 
  125.  
  126. extern Sint   C_float_to_Sint();
  127. extern void   L_modify_learning();
  128. extern void   sys_init_rand();
  129. extern void   CC_create_delivery();
  130.  
  131.  
  132. /*
  133. ----------------------------------------------------------------------
  134.   GLOBAL VARIABLES
  135. ----------------------------------------------------------------------
  136.  Next come the global variables, declared in other routines, which    
  137.   need to be referenced here.  In general I tried to keep the number  
  138.   of globals to a minimum since they can be messy, but many of the    
  139.   io functions needed to keep "state" variables for the lifetime of   
  140.   the program execution.  Examples are the default file names which   
  141.   are used when prompting the user.  These names are read here, and   
  142.   referenced in  net.c as well as netio.c, and thus I needed to be    
  143.   able to pass them around.  I could have simply left them as global  
  144.   values and referenced them as needed, but instead I passed them     
  145.   explicitly to the non-io routines that needed them.  I could just   
  146.   as easily have declared them here and made them external to the     
  147.   io package, but since file names and IO are so intimately tied, I   
  148.   thought it more logical to declare these variables with the other   
  149.   IO code.                                                            
  150. ----------------------------------------------------------------------
  151. */
  152. extern char  net_config[];                 /* these three from netio.c  */
  153. extern char  net_iop[];
  154. extern char  net_fwt[];
  155. extern char  net_pwt[];
  156. extern char  IO_str[MAX_LINE_SIZE];
  157.  
  158. static Net  *the_net;          /* variable which holds ptr to the net   */
  159.                                /* (global to this netmain routine only) */
  160.                                
  161.  
  162. /*
  163. ======================================================================
  164.   ROUTINES IN NETMAIN.C
  165. ======================================================================
  166.   The routines in this file are grouped below by function.  NO ROUTINE
  167.   ARE PREFIXED IN THIS FILE. 
  168.  
  169.   The main routine is just below.  It basically amounts to a read-eval-
  170.   print loop, much like an interpreter. After initialization, this code
  171.   loops forever, calling "print_menu", "read_choice", and then 
  172.   "evaluate".  The only other routine is a "check_net_ptr" routine 
  173.   used during evaluate to verify that a valid net exists before 
  174.   attempting some operation.
  175. ======================================================================
  176. */
  177.  
  178.  
  179. main()
  180. BEGIN
  181.    void  initialize(), print_menu(), evaluate(), cleanup();
  182.    char  read_choice();
  183.    char  c;
  184.       
  185.    initialize();
  186.    print_menu();
  187.    while (TRUE) BEGIN
  188.       c = read_choice();
  189.       if (c == 'q') BEGIN
  190.          cleanup();
  191.          break;
  192.       ENDIF
  193.       evaluate(c);
  194.    ENDWHILE
  195.  
  196. END /* main */
  197.  
  198.  
  199. void initialize()
  200. /*
  201. -----------------------------------------------------------------
  202.  This is the initialization routine for the main program. Much   
  203.   of what the program is able to do is assumed here. For example 
  204.   I have assumed (for now) that we will be dealing with only one 
  205.   net at a time (even though I have made allowances in the NET   
  206.   structure for multiple nets).  Thus, this guy only initializes 
  207.   the one global 'the_net' variable.  This would need changing   
  208.   in the future if more nets were added.                         
  209.  Note also that the system random number generator is setup here 
  210.   This is significant since IT MAY NOT BE PORTABLE.              
  211. -----------------------------------------------------------------
  212. */
  213. BEGIN
  214.    the_net = NULL;
  215.    sys_init_rand();
  216.    PA_initialize();
  217.    D_initialize();
  218.    
  219. END /* initialize */
  220.  
  221.  
  222. void print_menu()
  223. /*
  224. ----------------------------------------------------------------------
  225.  This routine prints out the variety of choices available to the user 
  226.   for operating on/creating his net.                                  
  227. ----------------------------------------------------------------------
  228. */
  229. BEGIN
  230.    sprintf(IO_str, "\n\n");
  231.    IO_print(0);
  232.    sprintf(IO_str, "NASA-JSC Artificial Intelligence Section\n");
  233.    IO_print(0);
  234.    sprintf(IO_str, "NETS Back Propagation Simulator Version 2.0\n\n");
  235.    IO_print(0);
  236.    sprintf(IO_str, "b -- show bias values\n");
  237.    IO_print(0);
  238.    sprintf(IO_str, "c -- create a net\n");
  239.    IO_print(0);
  240.    sprintf(IO_str, "d -- set dribble parameters\n");
  241.    IO_print(0);
  242.    sprintf(IO_str, "g -- generate delivery code\n");
  243.    IO_print(0);
  244.    sprintf(IO_str, "i -- setup I/O pairs\n");
  245.    IO_print(0);
  246.    sprintf(IO_str, "j -- reset I/O pairs\n");
  247.    IO_print(0);
  248.    sprintf(IO_str, "l -- change learning rate\n");
  249.    IO_print(0);
  250.    sprintf(IO_str, "m -- print this menu\n");
  251.    IO_print(0);
  252.    sprintf(IO_str, "n -- show net configuration\n");
  253.    IO_print(0);
  254.    sprintf(IO_str, "o -- reorder I/O pairs for training\n");
  255.    IO_print(0);
  256.    sprintf(IO_str, "p -- propagate an input through the net\n");
  257.    IO_print(0);
  258.    sprintf(IO_str, "r -- reset weights from a file\n");
  259.    IO_print(0);
  260.    sprintf(IO_str, "s -- save weights to a file\n");
  261.    IO_print(0);
  262.    sprintf(IO_str, "t -- teach the net\n");
  263.    IO_print(0);
  264.    sprintf(IO_str, "w -- show weights between two layers \n");
  265.    IO_print(0);
  266.    sprintf(IO_str, "q -- quit program\n");
  267.    IO_print(0);
  268.  
  269. END /* print_menu */
  270.  
  271.  
  272. char  read_choice()
  273. /*
  274. ----------------------------------------------------------------------
  275.  Now, this routine will seem simple, but I wanted to separate it out  
  276.   in the event that the future brings more sophisticated input.  Also 
  277.   this program needs to make sure that the input read in is legal,    
  278.   which is really not part of the print_menu function at all.         
  279. ----------------------------------------------------------------------
  280. */
  281. BEGIN
  282.    char  result[MAX_LINE_SIZE];
  283.  
  284.    sprintf(IO_str, "\nNETS Choice(m = menu)? ");
  285.    IO_print(0);
  286.    gets(result);
  287.    while (result[0] != 'b' 
  288.           && result[0] != 'c' 
  289.           && result[0] != 'd' 
  290.           && result[0] != 'g' 
  291.           && result[0] != 'i' 
  292.           && result[0] != 'j' 
  293.           && result[0] != 'l' 
  294.           && result[0] != 'm' 
  295.           && result[0] != 'n' 
  296.           && result[0] != 'o' 
  297.           && result[0] != 'p' 
  298.           && result[0] != 'r' 
  299.           && result[0] != 's' 
  300.           && result[0] != 't' 
  301.           && result[0] != 'w' 
  302.           && result[0] != 'q') BEGIN
  303.       sprintf(IO_str, "\nSorry, I don't understand that input.\n");
  304.       IO_print(0);
  305.       sprintf(IO_str, "Please try again: ");
  306.       IO_print(0);
  307.       gets(result);
  308.    ENDWHILE
  309.    return(result[0]);
  310.  
  311. END /* read_choice */
  312.  
  313.  
  314. void  evaluate(input)
  315. char  input;
  316. /*
  317. ----------------------------------------------------------------------
  318.  Evaluates the input command and calls the appropriate function to    
  319.   carry out the actions                                               
  320. ----------------------------------------------------------------------
  321. */
  322. BEGIN
  323.    int    t1, t2, check_net_ptr();
  324.    float  tf1;
  325.    char   filename[MAX_WORD_SIZE], file2[MAX_WORD_SIZE], 
  326.           tmp_str[MAX_WORD_SIZE];
  327.    FILE   *fp; /* for query from file */
  328.  
  329.    switch (input) BEGIN
  330.       case 'b': BEGIN
  331.          if (check_net_ptr(the_net) == OK) BEGIN
  332.             if (the_net->use_biases == FALSE) BEGIN
  333.                sprintf(IO_str, "\n*** network has no biases");
  334.                IO_print(0);
  335.             ENDIF
  336.             else BEGIN
  337.                sprintf(IO_str, "\n   layer number? ");
  338.                IO_print(0);
  339.                t1 = IO_my_get_int();
  340.                S_show_biases(the_net, t1);
  341.             ENDELSE
  342.          ENDIF
  343.          break;
  344.       ENDCASE
  345.       case 'c': BEGIN
  346.          the_net = B_free_net(the_net);
  347.          IO_set_filenames();
  348.          the_net = B_create_net(1, net_config);
  349.          break;
  350.       ENDCASE
  351.       case 'd': BEGIN
  352.          D_dribble_status();
  353.          break;
  354.       ENDCASE
  355.       case 'g': BEGIN
  356.          if (check_net_ptr(the_net) == OK) BEGIN
  357.             /*--------------------------------*/
  358.             /* use the net_config filename as */
  359.             /* the configuration file.        */
  360.             /*--------------------------------*/
  361.              strcpy(filename, net_config);
  362.             /*-----------------------------*/
  363.             /* prompt for the weights file */
  364.             /*-----------------------------*/
  365.             sprintf(IO_str, "\n   Enter filename with PORTABLE weight values");
  366.             IO_print(0);
  367.             sprintf(IO_str, " (default=%s): ", net_pwt);
  368.             IO_print(0);
  369.             IO_my_get_string(file2);
  370.             if (file2[0] == ENDSTRING) 
  371.                strcpy(file2, net_pwt);
  372.  
  373.             /*-------------------------*/
  374.             /* call the create routine */
  375.             /*-------------------------*/
  376.             CC_create_delivery(the_net, filename, file2);
  377.          ENDIF
  378.          break;
  379.       ENDCASE
  380.       case 'i': BEGIN
  381.          if (check_net_ptr(the_net) == OK) BEGIN
  382.             IO_get_io_name();
  383.             PA_setup_iopairs(the_net, net_iop);
  384.          ENDIF
  385.          break;
  386.       ENDCASE
  387.       case 'j': BEGIN
  388.          if (check_net_ptr(the_net) == OK) BEGIN
  389.             sprintf(IO_str, "\n   Resetting I/O pairs to 'workfile.net'");
  390.             IO_print(0);
  391.             PA_reset_iopairs(the_net, "workfile.net");
  392.          ENDIF
  393.          break;
  394.       ENDCASE
  395.       case 'l': BEGIN
  396.          if (check_net_ptr(the_net) == OK) BEGIN
  397.             sprintf(IO_str, "\n   Layer number: ");
  398.             IO_print(0);
  399.             t1 = IO_my_get_int();
  400.             L_modify_learning(t1, N_get_layer(the_net, t1));
  401.          ENDIF
  402.          break;
  403.       ENDCASE
  404.       case 'm' : BEGIN
  405.          print_menu();
  406.          break;
  407.       ENDCASE
  408.       case 'n': BEGIN
  409.          if(check_net_ptr(the_net) == OK)
  410.             S_show_net(the_net);
  411.          break;
  412.       ENDCASE
  413.       case 'o': BEGIN
  414.          if(check_net_ptr(the_net) == OK)
  415.             if (the_net->num_io_pairs <= 0) BEGIN
  416.                sprintf(IO_str, "\n\n*** no valid set of io pairs to randomize ***\n");
  417.                IO_print(0);
  418.             ENDIF
  419.             else BEGIN
  420.                sprintf(IO_str, "\n   Enter filename for reordered IO pairs: ");
  421.                IO_print(0);
  422.                IO_my_get_string(filename);
  423.                while (filename[0] == ENDSTRING) BEGIN
  424.                   sprintf(IO_str, "\n      try again: ");
  425.                   IO_print(0);
  426.                   IO_my_get_string(filename);
  427.                ENDWHILE
  428.                PA_randomize_file(filename, "workfile.net", the_net->num_io_pairs,
  429.                                  the_net->num_inputs + the_net->num_outputs);
  430.             ENDELSE
  431.  
  432.          break;
  433.       ENDCASE
  434.       case 'p': BEGIN
  435.          if (check_net_ptr(the_net) == OK) BEGIN
  436.             /*----------------------------------*/
  437.             /* Get the input file/screen source */
  438.             /* and number of IO pairs to prop.  */
  439.             /*----------------------------------*/
  440.             sprintf(IO_str, "\n   Enter filename with INPUT(default=manual entry): ");
  441.             IO_print(0);
  442.             IO_my_get_string(filename);
  443.             
  444.             if (filename[0] != ENDSTRING) BEGIN
  445.                sprintf(IO_str, "\n   Enter number of I/O pairs to propagate (default=all): ");
  446.                IO_print(0);
  447.                t1 = IO_get_default_int(-1);
  448.             ENDIF
  449.             
  450.             /*----------------------------------------*/
  451.             /* Get the output file/screen destination */
  452.             /*----------------------------------------*/
  453.             sprintf(IO_str, "\n   Enter OUTPUT destination file(default=screen): ");
  454.             IO_print(0);
  455.             IO_my_get_string(file2);
  456.             if (file2[0] == ENDSTRING)
  457.                fp = NULL;
  458.             else 
  459.                fp = fopen(file2, "wt");
  460.             
  461.             /*-------------------------------*/
  462.             /* call the query function, then */
  463.             /* tidy up output file if used   */
  464.             /*-------------------------------*/
  465.             N_query_net(the_net, filename, fp, t1);
  466.             if (fp != NULL) fclose(fp);
  467.          ENDIF
  468.          break;
  469.       ENDCASE
  470.       case 'r': BEGIN
  471.          if (check_net_ptr(the_net) == OK) BEGIN
  472.             sprintf(IO_str,"\n   Load weights from FAST or PORTABLE format");
  473.             IO_print(0);
  474.             sprintf(IO_str, "(f/p, default=f)? ");
  475.             IO_print(0);
  476.             IO_my_get_string(tmp_str);
  477.             if ((tmp_str[0] == 'p') || (tmp_str[0] == 'P'))
  478.                t1 = PORTABLE_FORMAT;
  479.             else t1 = FAST_FORMAT;
  480.             sprintf(IO_str, "\n   Enter name of file with weight values");
  481.             IO_print(0);
  482.             IO_get_wts_name(t1);
  483.             t1 = ( (t1 == FAST_FORMAT) 
  484.                    ? N_reset_wts(the_net, net_fwt, t1)
  485.                    : N_reset_wts(the_net, net_pwt, t1) );
  486.             if (t1 == ERROR) BEGIN
  487.                sprintf(IO_str, "\n*** weight resetting was incomplete ***");
  488.                IO_print(0);
  489.             ENDIF
  490.          ENDIF
  491.          break;
  492.       ENDCASE
  493.       case 's': BEGIN
  494.          if (check_net_ptr(the_net) == OK) BEGIN
  495.             sprintf(IO_str,"\n   Save weights in FAST or PORTABLE format");
  496.             IO_print(0);
  497.             sprintf(IO_str, "(f/p, default=f)? ");
  498.             IO_print(0);
  499.             IO_my_get_string(tmp_str);
  500.             if ((tmp_str[0] == 'p') || (tmp_str[0] == 'P'))
  501.                t1 = PORTABLE_FORMAT;
  502.             else t1 = FAST_FORMAT;
  503.             sprintf(IO_str, "\n   Enter file name for storing weights");
  504.             IO_print(0);
  505.             IO_get_wts_name(t1);
  506.             if (t1 == FAST_FORMAT)
  507.                N_save_wts(the_net, net_fwt, t1);
  508.             else N_save_wts(the_net, net_pwt, t1);
  509.          ENDIF
  510.          break;
  511.       ENDCASE
  512.       case 't': BEGIN
  513.          if(check_net_ptr(the_net) == OK) BEGIN
  514.             if (the_net->num_io_pairs <= 0) BEGIN
  515.                sprintf(IO_str, "\n\n*** no valid set of io pairs to teach ***\n");
  516.                IO_print(0);
  517.             ENDIF
  518.             else BEGIN
  519.                sprintf(IO_str, "\n   Enter constraint error: ");
  520.                IO_print(0);
  521.                tf1 = IO_my_get_float();
  522.                sprintf(IO_str, "\n   Enter max number of cycles(default=%d): ", MAX_CYCLES);
  523.                IO_print(0);
  524.                t1  = IO_get_num_cycles();
  525.                sprintf(IO_str, "\n   Enter cycle increment for showing errors");
  526.                IO_print(0);
  527.                sprintf(IO_str, "(default=1): ");
  528.                IO_print(0);
  529.                t2  = IO_get_default_int(1);
  530.                T_teach_net(the_net, C_float_to_Sint(tf1), t1, t2);
  531.             ENDELSE
  532.          ENDIF
  533.          break;
  534.       ENDCASE
  535.       case 'w': BEGIN
  536.          if (check_net_ptr(the_net) == OK) BEGIN
  537.             sprintf(IO_str, "\n   source layer? ");
  538.             IO_print(0);
  539.             t1 = IO_my_get_int();
  540.             sprintf(IO_str, "\n   target layer? ");
  541.             IO_print(0);
  542.             t2 = IO_my_get_int();
  543.             S_show_weights(the_net, t1, t2);
  544.          ENDIF
  545.          break;
  546.       ENDCASE
  547.       default :
  548.          break;
  549.    ENDSWITCH
  550.  
  551. END /* evaluate */
  552.  
  553.  
  554. int  check_net_ptr(ptr_net)
  555. Net  *ptr_net;
  556. /*
  557. ----------------------------------------------------------------------
  558.  This small routine is simply a check to see whether or not the arg   
  559.   passed in is actually a valid Net pointer.  If a net has not been   
  560.   created, or if there was an error in its creation, then no          
  561.   operations ought to be allowed using the net pointer.               
  562.  Returns 0 if the ptr is INvalid, otherwise returns a 1.  Note that   
  563.   a net which has an error will return a non-NULL pointer, but the ID 
  564.   of such a net will be negative. (see return values of B_create_net) 
  565. ----------------------------------------------------------------------
  566. */
  567. BEGIN
  568.    if (ptr_net == NULL) BEGIN
  569.       sprintf(IO_str, "\n\n*** no valid net exists ***\n");
  570.       IO_print(0);
  571.       return(0);
  572.    ENDIF
  573.    if (ptr_net->ID == ERROR) BEGIN
  574.       sprintf(IO_str, "\n\n*** no valid net exists ***\n");
  575.       IO_print(0);
  576.       return(0);
  577.    ENDIF
  578.    return(OK);
  579.  
  580. END /* check_net_ptr */
  581.  
  582.  
  583. void  cleanup()
  584. /*
  585. ----------------------------------------------------------------------
  586.    Before quitting the program entirely, this routine ensures that all
  587.    the loose ends are accounted for. Right now, that consists of freeing
  588.    all the memory currently used for the network.
  589. ----------------------------------------------------------------------
  590. */
  591. BEGIN
  592.  
  593.    the_net = B_free_net(the_net);
  594.    
  595. END /* cleanup */
  596.  
  597.